Variable elimination #24
Conversation
```python
        Dict with 'input' and 'target' for loss computation.
        """
        if isinstance(forward_out, dict):
            return {'input': forward_out['log_joint'], 'target': target}
```
We must find a way not to hard-code 'log_joint' here.
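One possible direction (a minimal sketch; the `loss_input_key` argument and the `FilterForLoss` name are hypothetical, not part of this PR): make the key a constructor parameter, so inferences that expose something other than `log_joint` can still plug in.

```python
class FilterForLoss:
    def __init__(self, loss_input_key: str = 'log_joint'):
        # Hypothetical: which entry of the model's output dict
        # should feed the loss as 'input'.
        self.loss_input_key = loss_input_key

    def __call__(self, forward_out, target):
        if isinstance(forward_out, dict):
            return {'input': forward_out[self.loss_input_key], 'target': target}
        return {'input': forward_out, 'target': target}
```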
```python
        Dict with 'preds' and 'target' for metric computation.
        """
        if isinstance(forward_out, dict):
            return {'preds': forward_out['logits'], 'target': target}
```
```python
# Distribution type groups.
_BINARY_DISTRIBUTIONS = {Bernoulli, RelaxedBernoulli}
_CATEGORICAL_DISTRIBUTIONS = {Categorical, OneHotCategorical, RelaxedOneHotCategorical}
_CONTINUOUS_DISTRIBUTIONS = {Normal, MultivariateNormal, Delta}
```
Is it safe to have such a limited pool of distributions here? What would happen if a user asks for a different distribution? Would the code raise an error?
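If unsupported distributions should fail loudly, one option is a lookup that raises explicitly (a sketch; `_distribution_group` is a hypothetical helper built on the sets above):

```python
def _distribution_group(dist_cls) -> str:
    # Uses the _BINARY/_CATEGORICAL/_CONTINUOUS sets defined above.
    if dist_cls in _BINARY_DISTRIBUTIONS:
        return 'binary'
    if dist_cls in _CATEGORICAL_DISTRIBUTIONS:
        return 'categorical'
    if dist_cls in _CONTINUOUS_DISTRIBUTIONS:
        return 'continuous'
    raise NotImplementedError(
        f"Distribution {dist_cls.__name__} is not supported; "
        "supported groups are binary, categorical, and continuous."
    )
```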
```python
        target = kwargs['target']

        # Flat 2D tensor → fallback to BCE (eval inference path)
        if input.ndim == 2:
```
I am not sure about this. Why would I need different behavior for evaluation? A loss should not be aware of the inference mode.
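One way to keep that separation (a minimal sketch; `pick_loss` and this wiring are hypothetical, not part of the PR): let the training module select the criterion from the output type, so the loss itself never inspects `input.ndim`.

```python
import torch.nn as nn

def pick_loss(forward_out) -> nn.Module:
    # Hypothetical dispatcher living in the training module, not in
    # any loss: dict outputs carry a log-joint and get the PR's
    # JointNLLLoss; plain 2D logits get a standard BCE criterion.
    if isinstance(forward_out, dict):
        return JointNLLLoss()  # defined later in this diff
    return nn.BCEWithLogitsLoss()
```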
```python
            return torch.tensor(0.0, device=input.device)


class JointNLLLoss(nn.Module):
```
Can we use a torch class for the NLL loss? This sounds too hard-coded to the inferences that return log_joint.
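For reference, torch's built-in NLL classes are shape-specific (`nn.NLLLoss` expects per-class log-probabilities, `nn.GaussianNLLLoss` expects Gaussian parameters), so a model-agnostic "minimize −log p(x, y)" still seems to need a thin wrapper. A minimal sketch (the class name is hypothetical):

```python
from typing import Optional

import torch
import torch.nn as nn

class NegativeLogLikelihood(nn.Module):
    """Generic NLL over a batch of log-probabilities."""

    def forward(self, input: torch.Tensor,
                target: Optional[torch.Tensor] = None) -> torch.Tensor:
        # `input` holds log p(x, y) per batch element; `target` is
        # accepted for signature compatibility but is unused, since
        # the observations are already folded into the log-joint.
        return -input.mean()
```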
```python
        super().__init__()
        self._bce_fallback = nn.BCEWithLogitsLoss()

    def forward(self, **kwargs) -> torch.Tensor:
```
The loss signature should be explicit. Right now, the keys of the dict returned by filter_output_for_loss must match the forward signature.
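A sketch of what an explicit signature could look like, with the keys of the filtered dict (`input`, `target`) promoted to named parameters so a mismatch fails with a clear `TypeError` rather than a `KeyError` inside `forward`:

```python
import torch
import torch.nn as nn

class JointNLLLoss(nn.Module):
    # Sketch only: the real loss in this PR does more than negate
    # the mean log-joint; this just illustrates the signature.
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return -input.mean()

# The filtered dict then unpacks directly into the signature:
# loss = JointNLLLoss()
# loss(**filter_output_for_loss(forward_out, target))
```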
In this PR, the following things have been implemented: